import numpy as np
import torch
from rlcore.algo import JointPPO
from rlagent import Neo
from mpnn import MPNN
from manual import INVADER_Control
from utils import make_multiagent_env


def Get_ENV_Learner(args, env=None, return_env=False):
    """Create the multi-agent environment (if needed) and its EnvLearner.

    Adversary agents ("invaders") share one INVADER_Control policy
    (presumably a scripted controller, from `manual` — confirm), while the
    remaining agents ("hunters") share one learned MPNN policy. Each agent is
    wrapped as a Neo actor around its team's shared policy.

    Args:
        args: experiment configuration namespace (env name, sizes, device,
            PPO hyper-parameters, ...).
        env: pre-built environment, or None to build one here.
        return_env: when True, return ``(learner, env)`` instead of just the
            learner; also selects process_index 0 instead of -1.

    Returns:
        EnvLearner, or (EnvLearner, env) when `return_env` is True.
    """
    process_index = 0 if return_env else -1
    if env is None:
        env = make_multiagent_env(args.env_name,
                                  num_agents=args.num_agents,
                                  dist_threshold=args.dist_threshold,
                                  arena_size=args.arena_size,
                                  identity_size=args.identity_size,
                                  process_index=process_index)

    agents = env.world.policy_agents
    num_adversary = sum(1 for a in agents if getattr(a, 'adversary', False))
    num_friendly = len(agents) - num_adversary

    # Agents in a team share a common policy; the last agent's spaces are used
    # as the template (assumes homogeneous obs/action spaces across agents —
    # this matches the original code, which reused the loop index after the
    # counting loop above).
    last = len(agents) - 1
    action_space = env.action_space[last]

    entity_mp = args.entity_mp  # enable entity message passing

    # Environments whose observations contain two non-agent entities.
    two_entity_envs = {
        'simple_line', 'simple_tag', 'simple_tag2', 'simple_tag3',
        'simple_tag4', 'simple_tag5', 'simple_tag6', 'simple_tag8',
        'simple_tag9', 'simple_tag10', 'simple_tag14', 'simple_tag15',
        'simple_tag16', 'simple_formation_3', 'simple_formation_4',
        'simple_formation_6', 'hunter_invader',
    }
    if args.env_name == 'simple_spread':
        num_entities = args.num_agents
    elif args.env_name == 'simple_formation':
        num_entities = 1
    elif args.env_name in two_entity_envs:
        num_entities = 2
    else:
        raise NotImplementedError('Unknown environment, define entity_mp for this!')

    if entity_mp:
        # With entity message passing, the 2-D (x, y) of each entity is fed
        # through the message-passing path rather than the flat observation.
        pol_obs_dim = env.observation_space[last].shape[0] - 2 * num_entities
    else:
        pol_obs_dim = env.observation_space[last].shape[0]

    assert args.env_name == "hunter_invader"    # planted guard: only hunter_invader is supported below

    # Index at which the agent's own position starts inside its observation.
    pos_index = args.identity_size + 2

    policy1 = None  # shared invader policy
    policy2 = None  # shared hunter policy
    team1 = []      # invader (adversary) agents
    team2 = []      # hunter (friendly) agents
    for i, agent in enumerate(agents):
        obs_dim = env.observation_space[i].shape[0]
        if getattr(agent, 'adversary', False):
            # Invader team shares a single INVADER_Control policy.
            if policy1 is None:
                policy1 = INVADER_Control(input_size=pol_obs_dim, num_agents=num_adversary,
                                          num_entities=num_entities, action_space=action_space,
                                          pos_index=pos_index, mask_dist=args.mask_dist,
                                          entity_mp=entity_mp, device=args.device).to(args.device)
            team1.append(Neo(args, policy1, (obs_dim,), action_space))
        else:
            # Hunter team shares a single learned MPNN policy.
            if policy2 is None:
                policy2 = MPNN(input_size=pol_obs_dim, num_agents=num_friendly,
                               num_entities=num_entities, action_space=action_space,
                               pos_index=pos_index, mask_dist=args.mask_dist,
                               entity_mp=entity_mp, device=args.device).to(args.device)
            team2.append(Neo(args, policy2, (obs_dim,), action_space))

    learner = EnvLearner(args, [team1, team2], [policy1, policy2], env)

    if args.continue_training:
        print("Loading pretrained model")
        learner.load_models(torch.load(args.load_dir)['models'])

    if return_env:
        return learner, env
    return learner


class EnvLearner(object):
    """Centralized trainer for one or more teams of policy-sharing agents.

    Each team shares a single policy network: at every step the whole team's
    observations are concatenated and pushed through one forward pass, and
    the joint output is chunked back into per-agent pieces. Each team with a
    trainable policy is optimized with JointPPO.
    """

    def __init__(self, args, teams_list, policies_list, env):
        """
        Args:
            args: namespace carrying PPO hyper-parameters and `device`.
            teams_list: list of teams (lists of agents); empty teams are dropped.
            policies_list: one shared policy per team; None entries are dropped.
            env: the multi-agent environment (used by `eval_act` to map agent
                index -> team membership).
        """
        self.teams_list = [team for team in teams_list if team]
        self.all_agents = [agent for team in teams_list for agent in team]
        self.policies_list = [p for p in policies_list if p is not None]

        self.trainers_list = []
        for policy in self.policies_list:
            if len(list(policy.parameters())) == 0:
                # Parameter-less policy (e.g. a scripted controller): nothing
                # to optimize, keep a placeholder so indices stay aligned.
                self.trainers_list.append(None)
            else:
                self.trainers_list.append(
                    JointPPO(policy, args.clip_param, args.ppo_epoch,
                             args.num_mini_batch, args.value_loss_coef,
                             args.entropy_coef, lr=args.lr,
                             max_grad_norm=args.max_grad_norm,
                             use_clipped_value_loss=args.clipped_value_loss))

        self.device = args.device
        self.env = env

    @property
    def all_policies(self):
        """State dicts of every agent's actor-critic, in `all_agents` order."""
        return [agent.actor_critic.state_dict() for agent in self.all_agents]

    @property
    def team_attn(self):
        """Attention matrix of the first team's shared policy."""
        return self.policies_list[0].attn_mat

    def initialize_obs(self, obs):
        """Seed every agent's rollout buffer with the first observation.

        Args:
            obs: ndarray of shape (num_processes, num_agents, obs_dim).
        """
        for i, agent in enumerate(self.all_agents):
            agent.initialize_obs(torch.from_numpy(obs[:, i, :]).float().to(self.device))
            agent.rollouts.to(self.device)

    def act(self, step):
        """Select an action for every agent at rollout index `step`.

        Returns:
            List of per-agent action arrays (numpy), in team order.
        """
        actions_list = []
        for team, policy in zip(self.teams_list, self.policies_list):
            # Concatenate the team's rollout slices along the batch dimension
            # (process dim becomes processes * team size) for one joint pass.
            all_obs = torch.cat([agent.rollouts.obs[step] for agent in team])
            all_hidden = torch.cat([agent.rollouts.recurrent_hidden_states[step]
                                    for agent in team])
            all_masks = torch.cat([agent.rollouts.masks[step] for agent in team])

            # Single forward pass; props = (value, action, action_log_prob, states).
            props = policy.act(all_obs, all_hidden, all_masks, deterministic=False)

            # Split the joint outputs back into one chunk per agent.
            n = len(team)
            values, actions, log_probs, states = [torch.chunk(x, n) for x in props]
            for i, agent in enumerate(team):
                agent.value = values[i]
                agent.action = actions[i]
                agent.action_log_prob = log_probs[i]
                agent.states = states[i]
                actions_list.append(actions[i].cpu().numpy())

        return actions_list

    def update(self):
        """Run one JointPPO update per trainable team.

        Returns:
            ndarray of shape (num_trained_agents, 3); each agent of a team
            gets a copy of that team's stats (presumably value loss, action
            loss, entropy — verify against JointPPO.update).
        """
        return_vals = []
        for trainer, team in zip(self.trainers_list, self.teams_list):
            if trainer is None:
                continue  # scripted / parameter-less policy: skip
            rollouts_list = [agent.rollouts for agent in team]
            vals = trainer.update(rollouts_list)
            return_vals.append([np.array(vals)] * len(rollouts_list))

        return np.stack([x for v in return_vals for x in v]).reshape(-1, 3)

    def wrap_horizon(self):
        """Bootstrap each agent's rollout with the value of the last state."""
        for team, policy in zip(self.teams_list, self.policies_list):
            last_obs = torch.cat([agent.rollouts.obs[-1] for agent in team])
            last_hidden = torch.cat([agent.rollouts.recurrent_hidden_states[-1]
                                     for agent in team])
            last_masks = torch.cat([agent.rollouts.masks[-1] for agent in team])

            with torch.no_grad():
                next_value = policy.get_value(last_obs, last_hidden, last_masks)

            chunks = torch.chunk(next_value, len(team))
            for agent, value in zip(team, chunks):
                agent.wrap_horizon(value)

    def after_update(self):
        """Reset every agent's rollout storage after a PPO update."""
        for agent in self.all_agents:
            agent.after_update()

    def update_rollout(self, obs, reward, masks):
        """Append one environment transition to every agent's rollout.

        Args:
            obs: ndarray (num_processes, num_agents, obs_dim).
            reward: tensor (num_processes, num_agents).
            masks: tensor (num_processes, num_agents).
        """
        obs_t = torch.from_numpy(obs).float().to(self.device)
        for i, agent in enumerate(self.all_agents):
            agent_obs = obs_t[:, i, :]  # this agent's obs across all processes
            agent.update_rollout(agent_obs, reward[:, i].unsqueeze(1),
                                 masks[:, i].unsqueeze(1))

    def load_models(self, policies_list):
        """Load one saved state dict per agent, in `all_agents` order."""
        for agent, policy in zip(self.all_agents, policies_list):
            agent.load_model(policy)

    def eval_act(self, obs, recurrent_hidden_states, mask):
        """Deterministic action selection used only during evaluation.

        Assumes agents are ordered by team (adversaries/invaders first, then
        the friendly/hunter agents), matching `teams_list`. The
        `recurrent_hidden_states` and `mask` arguments are unused: policies
        are called with None for both.
        """
        adversary_obs = []
        friendly_obs = []
        grouped_obs = []
        for idx, agent_obs in enumerate(obs):
            agent = self.env.world.policy_agents[idx]
            obs_tensor = torch.as_tensor(agent_obs, dtype=torch.float,
                                         device=self.device).view(1, -1)
            if getattr(agent, 'adversary', False):
                adversary_obs.append(obs_tensor)
            else:
                friendly_obs.append(obs_tensor)

        if adversary_obs:
            grouped_obs.append(adversary_obs)
        if friendly_obs:
            grouped_obs.append(friendly_obs)

        actions = []
        for team, policy, team_obs in zip(self.teams_list, self.policies_list,
                                          grouped_obs):
            if team_obs:
                # One joint deterministic pass over the whole team's obs.
                _, action, _, _ = policy.act(torch.cat(team_obs), None, None,
                                             deterministic=True)
                actions.append(action.squeeze(1).cpu().numpy())

        return np.hstack(actions)

    def set_eval_mode(self):
        """Put every agent's actor-critic into eval mode."""
        for agent in self.all_agents:
            agent.actor_critic.eval()

    def set_train_mode(self):
        """Put every agent's actor-critic into train mode."""
        for agent in self.all_agents:
            agent.actor_critic.train()
